"""
Phase 1: Data Assembly for Capital Deepening Analysis
=====================================================
Combines multilateral panel, Penn World Table capital/TFP data,
bilateral gravity-predicted inflows (instruments), and WDI governance.
Output: capital_deepening/data/processed/deepening_panel.csv
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
# NOTE(review): this silences every warning category for the whole process,
# including pandas warnings that may point at real data problems.
warnings.filterwarnings('ignore')

# ── Paths ──────────────────────────────────────────────────────────────────
# This script sits one directory below PROJECT_DIR; the sibling analysis
# projects it reads from hang off ROOT_DIR.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR, GRAVITY_DIR, CAUSAL_DIR = (
    ROOT_DIR / name
    for name in ("multilateral", "gravity_bilateral", "causal_identification")
)

# Output locations are created eagerly at import time.
OUT_DATA = PROJECT_DIR.joinpath("data", "processed")
OUT_TABLES = PROJECT_DIR.joinpath("output", "tables")
OUT_DATA.mkdir(parents=True, exist_ok=True)
OUT_TABLES.mkdir(parents=True, exist_ok=True)


def load_full_panel():
    """Read the multilateral full panel and keep years 1990-2024 (inclusive).

    Returns a copy of the filtered panel and prints its size, country count
    and year span.
    """
    panel_path = MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv"
    panel = pd.read_csv(panel_path)
    panel = panel.loc[panel['year'].between(1990, 2024)].copy()

    n_countries = panel['iso3'].nunique()
    first_year, last_year = panel['year'].min(), panel['year'].max()
    print(f"Full panel: {len(panel)} obs, {n_countries} countries, "
          f"{first_year}-{last_year}")
    return panel


def _find_country_code_column(pwt):
    """Return the name of the ISO3 country-code column in *pwt*, or None."""
    for candidate in ('countrycode', 'country_code', 'iso3', 'ISO3'):
        if candidate in pwt.columns:
            return candidate
    # Fallback: the first text column whose non-missing values are all
    # exactly three characters long.
    for col in pwt.columns:
        if pwt[col].dtype == 'object':
            values = pwt[col].dropna().astype(str)
            if len(values) > 0 and values.str.len().eq(3).all():
                return col
    return None


def load_pwt(pwt_path=None):
    """Extract capital, TFP, employment and related series from Penn World Table 10.01.

    Parameters
    ----------
    pwt_path : str or Path, optional
        Location of the PWT workbook (sheet ``'Data'`` is read). Defaults to
        ``multilateral/data/raw/pwt1001.xlsx`` under the repository root.

    Returns
    -------
    pandas.DataFrame
        Columns ``iso3``, ``year`` and every PWT variable listed below that the
        workbook contains, restricted to 1990-2024. An empty DataFrame is
        returned (after a printed warning) when the file does not exist or no
        country-code column can be identified.
    """
    if pwt_path is None:
        pwt_path = MULTILATERAL_DIR / "data" / "raw" / "pwt1001.xlsx"
    pwt_path = Path(pwt_path)
    if not pwt_path.exists():
        print(f"WARNING: PWT file not found at {pwt_path}")
        return pd.DataFrame()

    pwt = pd.read_excel(pwt_path, sheet_name='Data')
    print(f"PWT raw: {len(pwt)} obs, columns: {list(pwt.columns[:20])}...")

    cc_col = _find_country_code_column(pwt)
    if cc_col is None:
        print("ERROR: Cannot find country code column in PWT")
        print(f"Columns: {list(pwt.columns)}")
        return pd.DataFrame()

    # Variable definitions follow the PWT 10.0 documentation
    # (verify against the workbook's own legend if the release changes).
    pwt_vars = [
        'cn',      # Capital stock at current PPPs (mil. 2017US$)
        'rnna',    # Capital stock at constant 2017 national prices (mil. 2017US$)
        'emp',     # Number of persons engaged (millions)
        'ctfp',    # TFP level at current PPPs (USA=1)
        'rtfpna',  # TFP at constant national prices (2017=1)
        'delta',   # Average depreciation rate of the capital stock
        'labsh',   # Share of labour compensation in GDP
        'csh_i',   # Share of gross capital formation at current PPPs
        'irr',     # Real internal rate of return
        'rgdpo',   # Output-side real GDP at chained PPPs (mil. 2017US$)
        'hc',      # Human capital index
    ]
    present = []
    for var in pwt_vars:
        if var in pwt.columns:
            present.append(var)
        else:
            print(f"  PWT variable '{var}' not found")

    # PWT names are kept as-is; only the country-code column is renamed.
    pwt_sub = pwt[[cc_col, 'year'] + present].rename(columns={cc_col: 'iso3'})
    pwt_sub = pwt_sub[(pwt_sub['year'] >= 1990) & (pwt_sub['year'] <= 2024)]
    print(f"PWT filtered: {len(pwt_sub)} obs, {pwt_sub['iso3'].nunique()} countries")
    return pwt_sub


def compute_pwt_derived(df):
    """Add capital-deepening, TFP-growth and MPK-proxy columns built from PWT data.

    The capital stock K is ``rnna`` (constant 2017 national prices, preferred
    for growth rates) when present, otherwise ``cn``. Employment is ``emp``.

    Columns added (when their inputs exist):
        capital_per_worker, log_kl     K / emp and its log
        capital_output_ratio           K / rgdpo                      (needs rgdpo)
        mpk_proxy                      (1 - labsh) * rgdpo / K        (needs rgdpo, labsh)
        delta_log_kl                   one-year change in log_kl
        log_ctfp, delta_log_tfp        (needs ctfp)
        log_rtfpna, delta_log_rtfpna   (needs rtfpna)

    Differences are taken within ``iso3`` after sorting by year. They are NaN
    unless the previous row of the same country is exactly one year earlier,
    so a gap in a country's series never shows up as one-year growth.

    Denominators and log arguments are floored at 1e-6, so zero or negative
    inputs yield large finite values rather than inf/NaN.

    Returns
    -------
    pandas.DataFrame
        The input object unchanged (after a warning) when no capital stock or
        ``emp`` column is available. Otherwise a new frame sorted by
        ``iso3``, ``year`` with a fresh RangeIndex. The caller's frame is not
        modified.
    """
    k_var = next((c for c in ('rnna', 'cn') if c in df.columns), None)
    if k_var is None or 'emp' not in df.columns:
        print("WARNING: capital stock or emp missing, cannot compute capital_per_worker")
        return df

    df = df.copy()  # don't add columns to the caller's frame
    floor = 1e-6
    capital = df[k_var]
    df['capital_per_worker'] = capital / df['emp'].clip(lower=floor)
    df['log_kl'] = np.log(df['capital_per_worker'].clip(lower=floor))
    print(f"Using '{k_var}' for capital stock")

    if 'rgdpo' in df.columns:
        df['capital_output_ratio'] = capital / df['rgdpo'].clip(lower=floor)
        if 'labsh' in df.columns:
            # Capital income share times output, per unit of capital.
            df['mpk_proxy'] = (1 - df['labsh']) * df['rgdpo'] / capital.clip(lower=floor)

    df = df.sort_values(['iso3', 'year']).reset_index(drop=True)
    # True only where the previous row of the same country is the prior year.
    one_year_step = df.groupby('iso3')['year'].diff().eq(1)

    def _annual_diff(col):
        # Within-country first difference, blanked across gaps in years.
        return df.groupby('iso3')[col].diff().where(one_year_step)

    df['delta_log_kl'] = _annual_diff('log_kl')
    if 'ctfp' in df.columns:
        df['log_ctfp'] = np.log(df['ctfp'].clip(lower=floor))
        df['delta_log_tfp'] = _annual_diff('log_ctfp')
    if 'rtfpna' in df.columns:
        df['log_rtfpna'] = np.log(df['rtfpna'].clip(lower=floor))
        df['delta_log_rtfpna'] = _annual_diff('log_rtfpna')

    return df


def load_gravity_coefficients(results_path=None,
                              model='2c: Gravity + Demographics + KAOPEN interactions'):
    """Load the coefficients of one gravity model from the gravity results table.

    Parameters
    ----------
    results_path : str or Path, optional
        CSV with at least ``model``, ``variable`` and ``coefficient`` columns.
        Defaults to ``gravity_bilateral/output/tables/gravity_results.csv``.
    model : str, optional
        Exact ``model`` label to select; defaults to the Model 2c label.

    Returns
    -------
    dict[str, float]
        Mapping of regressor name to coefficient. Rows whose ``variable`` is
        missing or starts with ``_`` are skipped (presumably model metadata
        rows rather than regressors; verify against the results writer).

    Raises
    ------
    ValueError
        If no usable coefficient rows exist for ``model``. An empty result
        would otherwise silently turn every gravity prediction into exp(0)=1.
    """
    if results_path is None:
        results_path = GRAVITY_DIR / "output" / "tables" / "gravity_results.csv"
    grav = pd.read_csv(results_path)
    rows = grav[grav['model'] == model]

    coeffs = {}
    for var, coef in zip(rows['variable'], rows['coefficient']):
        if not isinstance(var, str) or var.startswith('_'):
            continue
        coeffs[var] = float(coef)

    if not coeffs:
        raise ValueError(
            f"No coefficients found for model '{model}' in {results_path}")

    print(f"Coefficients for model '{model}': {coeffs}")
    return coeffs


def construct_gravity_instruments(coeffs, bp_path=None):
    """
    Build gravity-predicted bilateral inflows using Model 2c coefficients.
    Decompose into full predicted and demographic-component-only predicted.
    Aggregate by recipient-year.

    Parameters
    ----------
    coeffs : dict[str, float]
        Regressor name -> coefficient. Names with no matching bilateral
        column are ignored, with a printed warning.
    bp_path : str or Path, optional
        Bilateral panel CSV. Defaults to
        ``gravity_bilateral/data/processed/bilateral_panel.csv``.

    Returns
    -------
    pandas.DataFrame
        One row per destination (``iso3``) and ``year`` with summed actual
        portfolio/FDI inflows, reporter count, summed exp(linear index)
        predictions (full and demographic-only), and their logs.

    Notes
    -----
    Actual-inflow sums are NaN (not 0) when every reporter's value is missing
    for a destination-year, so missing data does not become log(1e-6).
    Only pairs with complete Model 2c regressors enter any aggregate.
    """
    if bp_path is None:
        bp_path = GRAVITY_DIR / "data" / "processed" / "bilateral_panel.csv"
    bp = pd.read_csv(bp_path)
    print(f"Bilateral panel: {len(bp)} obs, {bp['iso_d'].nunique()} destinations")

    # Variables used in Model 2c
    gravity_vars = ['log_dist', 'contiguity', 'common_lang_official',
                    'colonial_ties', 'log_gdp_product']
    demo_vars = ['dZ_1', 'dZ_2', 'dZ_3']
    demo_interactions = [f'{v}_x_kaopen_j' for v in demo_vars]
    interaction_vars = ['kaopen_j'] + demo_interactions
    # Demographic component = dZ levels plus their KAOPEN interactions; the
    # kaopen_j main effect is deliberately excluded.
    demo_terms = set(demo_vars) | set(demo_interactions)
    all_vars = gravity_vars + demo_vars + interaction_vars

    missing = [v for v in all_vars if v not in bp.columns]
    if missing:
        print(f"WARNING: Missing bilateral variables: {missing}")

    # Drop rows with missing values in key variables
    valid = bp.dropna(subset=[v for v in all_vars if v in bp.columns]).copy()
    print(f"Valid bilateral obs for prediction: {len(valid)}")

    unmatched = [v for v in coeffs if v not in valid.columns]
    if unmatched:
        print(f"WARNING: Coefficients with no bilateral column (ignored): {unmatched}")

    # Linear index: sum of coeff * variable over every matched regressor.
    valid['predicted_full'] = 0.0
    valid['predicted_demo'] = 0.0
    for var, coef in coeffs.items():
        if var not in valid.columns:
            continue
        term = coef * valid[var]
        valid['predicted_full'] += term
        if var in demo_terms:
            valid['predicted_demo'] += term

    # Convert from log-level predictions to levels (exp)
    valid['predicted_full_level'] = np.exp(valid['predicted_full'])
    valid['predicted_demo_level'] = np.exp(valid['predicted_demo'])

    # Aggregate by destination (recipient) country-year
    grouped = valid.groupby(['iso_d', 'year'])
    instruments = pd.DataFrame({
        'total_portfolio_inflows': grouped['portfolio_total'].sum(min_count=1),
        'total_fdi_inflows': grouped['fdi_outward'].sum(min_count=1),
        'n_reporters': grouped['iso_o'].nunique(),
        'predicted_total_inflows': grouped['predicted_full_level'].sum(),
        'predicted_demo_inflows': grouped['predicted_demo_level'].sum(),
    }).reset_index().rename(columns={'iso_d': 'iso3'})

    # Log transforms (floor at 1e-6; NaN stays NaN)
    for col in ['total_portfolio_inflows', 'total_fdi_inflows',
                'predicted_total_inflows', 'predicted_demo_inflows']:
        instruments[f'log_{col}'] = np.log(instruments[col].clip(lower=1e-6))

    print(f"Instruments: {len(instruments)} recipient-year obs, "
          f"{instruments['iso3'].nunique()} countries")
    return instruments


def load_wdi_governance():
    """Read governance and education indicators from the WDI extract.

    Keeps ``iso3``, ``year`` and whichever of the listed indicators exist in
    the file, restricted to 1990-2024 inclusive. Prints a warning and returns
    an empty DataFrame when the file is absent.
    """
    source = CAUSAL_DIR / "data" / "raw" / "wdi_additional.csv"
    if not source.exists():
        print(f"WARNING: WDI file not found at {source}")
        return pd.DataFrame()

    raw = pd.read_csv(source)
    wanted = ('rule_of_law', 'regulatory_quality', 'control_corruption',
              'tertiary_enrollment', 'govt_effectiveness')
    columns = ['iso3', 'year'] + [name for name in wanted if name in raw.columns]
    subset = raw[columns].copy()
    in_window = subset['year'].between(1990, 2024)
    subset = subset[in_window]
    print(f"WDI governance: {len(subset)} obs")
    return subset


def create_summary_stats(df, key_vars=None):
    """Generate a summary-statistics table for the key panel variables.

    Parameters
    ----------
    df : pandas.DataFrame
        Assembled panel.
    key_vars : list[str], optional
        Variables to summarise, in output order. Defaults to the capital-
        deepening key variables below. Names absent from ``df`` and
        non-numeric columns are skipped.

    Returns
    -------
    pandas.DataFrame
        Indexed by variable, with columns N, Mean, Std Dev, Min, P25, Median,
        P75, Max, rounded to 4 decimals. Empty (but with those columns) when
        no requested numeric variable is present, instead of raising.
    """
    if key_vars is None:
        key_vars = [
            'ca_gdp', 'gross_fixed_investment_gdp', 'gross_savings_gdp',
            'nfa_gdp', 'rgdp_growth', 'kaopen',
            'Z_1', 'Z_2', 'Z_3', 'old_dep', 'life_expectancy',
            'capital_per_worker', 'delta_log_kl', 'delta_log_tfp',
            'mpk_proxy', 'capital_output_ratio',
            'log_predicted_total_inflows', 'log_predicted_demo_inflows',
            'total_portfolio_inflows', 'total_fdi_inflows',
            'rule_of_law', 'hc',
        ]
    out_columns = ['N', 'Mean', 'Std Dev', 'Min', 'P25', 'Median', 'P75', 'Max']

    available = [v for v in key_vars if v in df.columns]
    # describe() raises on zero columns and switches to count/unique/top/freq
    # when nothing numeric is left, so restrict to numeric columns up front.
    numeric = df[available].select_dtypes(include='number')
    if numeric.shape[1] == 0:
        return pd.DataFrame(columns=out_columns, dtype=float)

    stats = numeric.describe().T
    stats['non_missing'] = numeric.notna().sum()
    stats = stats[['non_missing', 'mean', 'std', 'min', '25%', '50%', '75%', 'max']]
    stats.columns = out_columns
    return stats.round(4)


def _merge_pwt(fp, pwt):
    """Left-join PWT variables onto the full panel by (iso3, year).

    PWT columns whose names already exist in the full panel are skipped and
    the panel's values are kept, as for the WDI merge. This prevents ``_x``/``_y``
    suffixed duplicates that downstream code would not recognise.
    """
    if len(pwt) == 0:
        return fp.copy()
    keys = ['iso3', 'year']
    extra = [c for c in pwt.columns if c not in keys]
    overlap = [c for c in extra if c in fp.columns]
    if overlap:
        print(f"  PWT columns already in full panel, keeping panel values: {overlap}")
    new_cols = [c for c in extra if c not in fp.columns]
    return fp.merge(pwt[keys + new_cols], on=keys, how='left')


def _report_coverage(merged):
    """Print how many non-missing observations each key variable has."""
    # Report whichever capital-stock series is in the panel (rnna preferred).
    k_col = 'rnna' if 'rnna' in merged.columns else 'cn'
    key_checks = {
        k_col: 'PWT capital stock',
        'ctfp': 'PWT TFP',
        'delta_log_kl': 'Capital deepening growth',
        'delta_log_tfp': 'TFP growth',
        'predicted_demo_inflows': 'Gravity-predicted demographic inflows',
        'total_portfolio_inflows': 'Actual portfolio inflows',
        'rule_of_law': 'WDI governance',
        'mpk_proxy': 'MPK proxy',
    }
    for var, label in key_checks.items():
        if var in merged.columns:
            n = merged[var].notna().sum()
            print(f"  {label}: {n} non-missing obs")
        else:
            print(f"  {label}: MISSING from panel")


def main():
    """Run the full Phase 1 pipeline and return the assembled panel.

    Steps: load the full panel, merge PWT and derive capital/TFP variables,
    build and merge gravity instruments, merge WDI governance, then write
    ``deepening_panel.csv`` and the summary-statistics tables.
    """
    print("=" * 70)
    print("PHASE 1: DATA ASSEMBLY FOR CAPITAL DEEPENING")
    print("=" * 70)

    # Step 1: Full panel
    print("\n--- Step 1: Loading full panel ---")
    fp = load_full_panel()

    # Step 2: PWT
    print("\n--- Step 2: Loading Penn World Table ---")
    pwt = load_pwt()

    # Step 3: Merge PWT and compute derived variables
    print("\n--- Step 3: Merging PWT and computing derived variables ---")
    merged = _merge_pwt(fp, pwt)
    merged = compute_pwt_derived(merged)
    # compute_pwt_derived leaves the derived columns out when capital stock or
    # emp is unavailable (e.g. the PWT file was not found).
    n_kl = (merged['delta_log_kl'].notna().sum()
            if 'delta_log_kl' in merged.columns else 0)
    print(f"Observations with capital deepening data: {n_kl}")

    # Step 4-5: Gravity instruments
    print("\n--- Step 4-5: Constructing gravity-predicted instruments ---")
    coeffs = load_gravity_coefficients()
    instruments = construct_gravity_instruments(coeffs)

    merged = merged.merge(instruments, on=['iso3', 'year'], how='left')
    n_inst = merged['predicted_demo_inflows'].notna().sum()
    print(f"Observations with bilateral instruments: {n_inst}")

    # Step 6: WDI governance
    print("\n--- Step 6: Loading WDI governance ---")
    wdi = load_wdi_governance()
    if len(wdi) > 0:
        # Only keep columns not already in merged (plus merge keys)
        new_cols = [c for c in wdi.columns if c not in merged.columns]
        if new_cols:
            wdi_to_merge = wdi[['iso3', 'year'] + new_cols].copy()
            merged = merged.merge(wdi_to_merge, on=['iso3', 'year'], how='left')

    # Step 7: Save
    print("\n--- Step 7: Saving deepening_panel.csv ---")
    merged = merged.sort_values(['iso3', 'year']).reset_index(drop=True)
    merged.to_csv(OUT_DATA / "deepening_panel.csv", index=False)
    print(f"Saved: {len(merged)} obs, {merged['iso3'].nunique()} countries, "
          f"{len(merged.columns)} columns")

    # Step 8: Summary statistics
    print("\n--- Step 8: Summary statistics ---")
    stats = create_summary_stats(merged)
    stats.to_csv(OUT_TABLES / "summary_statistics.csv")

    # Markdown export needs the optional 'tabulate' package; the CSV above is
    # the primary output, so a missing dependency only skips this file.
    try:
        markdown = stats.to_markdown()
    except ImportError:
        print("WARNING: 'tabulate' not installed; skipping summary_statistics.md")
    else:
        with open(OUT_TABLES / "summary_statistics.md", 'w', encoding='utf-8') as f:
            f.write("# Summary Statistics: Capital Deepening Panel\n\n")
            f.write(markdown)
            f.write("\n")
    print(stats)

    # Verification
    print("\n" + "=" * 70)
    print("VERIFICATION")
    print("=" * 70)
    _report_coverage(merged)

    print(f"\nTotal panel: {len(merged)} obs × {len(merged.columns)} vars")
    print("Phase 1 complete.")
    return merged


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
